from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import os
import json
import argparse
import time

import numpy as np

import scipy.sparse as sparse
import data_utils as data
import datasets
import upper_bounds
import defenses
import iterative_attack
from upper_bounds import hinge_loss, hinge_grad, logistic_grad
from influence.influence.smooth_hinge import SmoothHinge
from influence.influence.BinarylogisticRegressionWithLBFGS import BinaryLogisticRegressionWithLBFGS
from influence.influence.dataset import DataSet
from tensorflow.contrib.learn.python.learn.datasets import base

import random 
import math

import tensorflow as tf

def prob_round(x):
    """Stochastically round ``x`` to one of the two integers bracketing it.

    The magnitude ``|x|`` is rounded up with probability equal to its
    fractional part and down otherwise; the sign of ``x`` is then reapplied.
    The result is therefore unbiased in expectation. Integral inputs
    (fractional part 0) are returned unchanged.
    """
    magnitude = abs(x)
    frac = magnitude - int(magnitude)
    go_up = random.random() < frac
    rounded = math.ceil(magnitude) if go_up else math.floor(magnitude)
    return np.sign(x) * rounded

def get_feature_bounds(X_train, X_test, dataset_name=None):
    """Return per-feature ``[lower, upper]`` bounds used to clip poisoned points.

    Args:
        X_train: training feature matrix (rows are examples).
        X_test: test feature matrix with the same number of columns.
        dataset_name: dataset identifier. When omitted, falls back to the
            ``--dataset`` CLI argument (module-level ``args``) so existing
            two-argument callers keep working.

    Returns:
        A two-element list ``[x_min, x_max]`` with one entry per feature:

        * mnist_17 / mnist_38 / mnist_69: the fixed range [0, 1].
        * enron / imdb: the empirical minimum over train+test, and the
          empirical maximum raised to at least 1.
        * anything else: the empirical min / max over train+test.
    """
    if dataset_name is None:
        dataset_name = args.dataset

    num_features = X_train.shape[1]
    if dataset_name in ['mnist_17', 'mnist_38', 'mnist_69']:
        # Fixed range: no need to stack or scan the data.
        return [np.zeros(num_features), np.ones(num_features)]

    full_x = data.vstack(X_train, X_test)
    x_min = np.amin(full_x, axis=0)
    x_max = np.amax(full_x, axis=0)
    if dataset_name in ['enron', 'imdb']:
        # Ensure every feature's upper bound is at least 1, even for features
        # that never exceed 0 in the data.
        x_max = np.maximum(x_max, 1)
    return [x_min, x_max]

def get_projection_fn_for_dataset(dataset_name, X, Y, use_slab, use_LP, percentile):
    """Build the defense-aware projection function for ``dataset_name``.

    Every supported dataset uses ``sphere=True`` (and ``slab=use_slab``) when
    calling ``data.get_projection_fn``. They differ only in the feature-range
    options and in LP rounding:

    * enron / imdb: ``non_negative=True``, LP rounding controlled by ``use_LP``.
    * mnist_17: ``non_negative=True`` and ``less_than_one=True``, no LP rounding.
    * dogfish: neither feature-range option, no LP rounding.

    Args:
        dataset_name: dataset identifier.
        X, Y: data the projection's parameters are computed from.
        use_slab: whether to include the slab constraint.
        use_LP: whether to use LP rounding (only honoured for enron / imdb).
        percentile: percentile forwarded to ``data.get_projection_fn``.

    Returns:
        The projection function produced by ``data.get_projection_fn``.

    Raises:
        ValueError: if ``dataset_name`` has no projection configuration.
            Previously such datasets failed with an UnboundLocalError.
    """
    if dataset_name in ['enron', 'imdb']:
        non_negative, less_than_one, use_lp_rounding = True, False, use_LP
    elif dataset_name in ['mnist_17']:
        non_negative, less_than_one, use_lp_rounding = True, True, False
    elif dataset_name in ['dogfish']:
        non_negative, less_than_one, use_lp_rounding = False, False, False
    else:
        raise ValueError(
            'no projection configuration for dataset %r '
            '(supported: enron, imdb, mnist_17, dogfish)' % (dataset_name,))

    return data.get_projection_fn(
        X, Y,
        sphere=True,
        slab=use_slab,
        non_negative=non_negative,
        less_than_one=less_than_one,
        use_lp_rounding=use_lp_rounding,
        percentile=percentile)

# ---- Fixed experiment settings (not exposed on the command line) ----

# Attack configuration. use_copy / use_LP may be switched off later by the
# --baseline, --baseline_smooth and --no_LP flags.
no_defense = True   # all training points are feasible starts; features are only box-clipped
use_copy = True
use_LP = True
loss_type = 'normal_loss'

# Model / optimiser configuration.
fit_intercept = True
num_classes = 2
batch_size = 100
initial_learning_rate = 0.001
keep_probs = None
decay_epochs = [1000, 10000]

# Command-line interface. Numeric options declared without ``type=`` are
# converted (via float, then rounded where an int is needed) after parsing, so
# both "10" and "10.0" are accepted.
parser = argparse.ArgumentParser()
parser.add_argument('--em_iter', default=0)
parser.add_argument('--total_grad_iter', default=10000)
parser.add_argument('--use_slab', action='store_true')
parser.add_argument('--dataset', default='mnist_17',help='options: dogfish, mnist_17, mnist_38, mnist_69, mnist_49, enron, imdb, cifar10_05')
parser.add_argument('--percentile', default=90)
parser.add_argument('--epsilon', default=0.03)
parser.add_argument('--step_size', default=0.1)
parser.add_argument('--use_test', action="store_true")
parser.add_argument('--baseline', action="store_true") # means no LP, no copy, and no smooth
parser.add_argument('--baseline_smooth', action="store_true") # means no LP, no copy
parser.add_argument('--no_LP', action="store_true")
parser.add_argument('--timed', action="store_true")

parser.add_argument('--force_iter', action="store_true")
parser.add_argument('--no_force_refresh', action="store_true")
parser.add_argument('--rand_seed', default=1234, type=int, help='seed for random number generator')
parser.add_argument('--shard', type=int)
# Only 'lr' and 'svm' have a model implementation; reject anything else up
# front instead of failing with a NameError deep inside the attack loop.
parser.add_argument('--model_type', default='svm', choices=['lr', 'svm'], help='models: lr, svm')
parser.add_argument('--pos_frac', default=-1, type=float, help='specific fraction of positive poisons, 0.2892')
parser.add_argument('--extreme_start', action="store_true",help="use extreme point as the starting point")
parser.add_argument('--small_shard', default=0, type=int, help='in case smaller epsilon results are needed')
parser.add_argument('--fixed_eps', action="store_true",help="attack at fixed eps")


args = parser.parse_args()

# Unpack parsed CLI arguments into module-level settings.
dataset_name = args.dataset
use_slab = args.use_slab

# Numeric options may arrive as strings; go through float so "10.0" works too.
step_size = float(args.step_size)
percentile, max_em_iter, total_grad_iter = (
    int(np.round(float(raw)))
    for raw in (args.percentile, args.em_iter, args.total_grad_iter))

use_train = not args.use_test
baseline, baseline_smooth = args.baseline, args.baseline_smooth
no_LP = args.no_LP
timed = args.timed

extreme_start = args.extreme_start
pos_frac = args.pos_frac

# Temperature handed to the model constructor: 0.001 for svm, 0.0 otherwise.
temp = 0.001 if args.model_type == 'svm' else 0.0

force_iter = args.force_iter
force_refresh = not args.no_force_refresh
shard, small_shard, fixed_eps = args.shard, args.small_shard, args.fixed_eps

# output_root is handed to the attack routines; train_output_root is the
# models' train_dir. Both are created on demand.
output_root = os.path.join(datasets.OUTPUT_FOLDER, dataset_name, 'influence_data', '%s' % args.model_type)
train_output_root = os.path.join(datasets.OUTPUT_FOLDER, dataset_name, 'influence_train_data', '%s' % args.model_type)
for _out_dir in (output_root, train_output_root):
    if not os.path.isdir(_out_dir):
        os.makedirs(_out_dir)
del _out_dir

print('use_slab: %s' % use_slab)
print("use train ?: ",use_train)

# Per-dataset settings.
#   discrete     -- forwarded to the attack; the no-defense clipping projection
#                   rounds features to integers when it is True.
#   partial_int  -- forwarded to the attack; set only for the adult dataset.
#   weight_decay -- regularisation strength passed to the victim model.
discrete = True
partial_int = False
if dataset_name == 'adult':
    partial_int = True

# Keep this a single if/elif chain. With a separate `if` for imdb, enron would
# fall through to the final `else` and its model-specific weight decay
# (0.01 for lr) would be silently overwritten with 0.09.
if dataset_name == 'enron':
    if args.model_type == 'lr':
        weight_decay = 0.01
    else:
        weight_decay = 0.09
    discrete = True
elif dataset_name == 'imdb':
    weight_decay = 0.01
    discrete = True
elif dataset_name in ['mnist_17','mnist_38','mnist_69','mnist_49']:
    weight_decay = 0.09 # 0.01
    discrete = False
elif dataset_name == 'dogfish':
    weight_decay = 0.09 # 1.1
    discrete = False
else:
    weight_decay = 0.09

# Baseline variants are only supported for enron evaluated on the test split.
# Both turn off copying and LP rounding and use percentile 80; --baseline
# additionally removes smoothing (temp = 0). The two flags are exclusive.
if baseline or baseline_smooth:
    assert dataset_name == 'enron'
    assert not (baseline and baseline_smooth)
    assert not use_train
    if baseline:
        temp = 0
    use_copy = False
    use_LP = False
    percentile = 80

# --no_LP keeps copying but turns off LP rounding (enron only).
if no_LP:
    assert dataset_name == 'enron'
    use_LP = False
    percentile = 80

# Split the gradient-step budget evenly across EM iterations; em_iter == 0
# means a single phase that gets the whole budget.
if max_em_iter == 0:
    num_grad_iter_per_em = total_grad_iter
else:
    assert total_grad_iter % max_em_iter == 0
    num_grad_iter_per_em = total_grad_iter // max_em_iter

X_train, Y_train, X_test, Y_test = datasets.load_dataset(dataset_name)

# Everything downstream works on dense arrays.
X_train = X_train.toarray() if sparse.issparse(X_train) else X_train
X_test = X_test.toarray() if sparse.issparse(X_test) else X_test

# Without --use_test, the "test" data is simply the training set.
if use_train:
    X_test, Y_test = X_train, Y_train

# Grid of poisoning rates. The attack loop skips every grid value that is not
# in `target_epsilons`, so targets must be values from the active grid.
epsilons = [0, 0.01, 0.02, 0.03, 0.05, 0.1, 0.15, 0.3]
small_epsilons = [0.001,0.002,0.003,0.005,0.007,0.009]
assert epsilons[0] == 0

# Choose this invocation's targets. --fixed_eps takes precedence over --shard;
# with neither, the single --epsilon value is used. This must be one if/elif
# chain: a separate `if shard == 0` would let its `else` overwrite the
# --fixed_eps choice.
if fixed_eps:
    target_epsilons = [0.03]
elif shard == 0:
    target_epsilons = epsilons[1:3]
elif shard == 1:
    target_epsilons = epsilons[3:5]
elif shard == 2:
    target_epsilons = epsilons[5:8]
else:
    target_epsilons = [float(args.epsilon)]

# --small_shard switches to the small-epsilon grid. For shards 1-3 it also
# overrides the targets chosen above; other positive values keep them.
if small_shard > 0:
    epsilons = small_epsilons
    if small_shard == 1:
        target_epsilons = small_epsilons[0:2]
    elif small_shard == 2:
        target_epsilons = small_epsilons[2:4]
    elif small_shard == 3:
        target_epsilons = small_epsilons[4:6]

input_dim = X_train.shape[1]

# The lr model is given labels mapped through (y + 1) / 2 and cast to int
# (presumably {-1, +1} -> {0, 1}; confirm against datasets.load_dataset).
# The svm model takes the labels unchanged.
if args.model_type == 'lr':
    train_labels = np.array((Y_train + 1) / 2, dtype=int)
    test_labels = np.array((Y_test + 1) / 2, dtype=int)
else:
    train_labels, test_labels = Y_train, Y_test

train = DataSet(X_train, train_labels)
validation = None
test = DataSet(X_test, test_labels)
data_sets = base.Datasets(train=train, validation=validation, test=test)

print(set(Y_train),set(Y_test))
print(X_train.shape,Y_train.shape,X_test.shape,Y_test.shape)


# Model identifier used in run names: the svm option is a smooth-hinge model.
if args.model_type == 'svm':
    model_type = 'smooth_hinge'
elif args.model_type == 'lr':
    model_type = 'lr'
else:
    # Fail fast. Otherwise `model_type` (and later the model itself) would be
    # left undefined, and the run would crash with an unhelpful NameError.
    raise ValueError("unsupported --model_type %r (expected 'lr' or 'svm')" % (args.model_type,))

def _clip_projection_fn(X, Y, X_lims, discrete=False, adult=False):
    """Clip points into the feature box ``X_lims`` (projection used with no defense).

    Args:
        X: points to project; converted to a dense array if sparse.
        Y: labels; accepted but not used.
        X_lims: ``[X_min, X_max]`` bounds, e.g. from ``get_feature_bounds``.
        discrete: when True, round the clipped values with ``data.rround``
            (original note: randomly round to nearest integer).
        adult: when True (and ``discrete``), only ``clipped_val[4:]`` is rounded.

    Returns:
        The clipped (and possibly rounded) array.
    """
    if sparse.issparse(X):
        X = X.toarray()
    X_min, X_max = X_lims
    clipped_val = np.clip(X, X_min, X_max)
    if discrete:
        if adult:
            # NOTE(review): [4:] slices the first axis. If X is 2-D this selects
            # rows 4+, not features 4+ -- confirm the intended shape.
            clipped_val[4:] = data.rround(clipped_val[4:])
        else:
            clipped_val = data.rround(clipped_val)
    return clipped_val


def _build_model_name(epsilon):
    """Return the run/checkpoint name encoding this run's settings for ``epsilon``."""
    if no_defense:
        if force_refresh:
            name_fmt = '%s_%s_no_defense_step-%s_iter-%s_t-%s_eps-%s_wd-%s_rs-%s'
        else:
            name_fmt = '%s_%s_no_defense_no_update_step-%s_iter-%s_t-%s_eps-%s_wd-%s_rs-%s'
        name = name_fmt % (model_type, dataset_name, step_size, total_grad_iter,
                           temp, epsilon, weight_decay, args.rand_seed)
        name += '_use_train' if use_train else '_use_test'
        if pos_frac > 0:
            name += '_%s' % pos_frac
        if extreme_start:
            name += '_extreme'
        return name

    name = '%s_%s_sphere-True_slab-%s_start-copy_lflip-True_step-%s_t-%s_eps-%s_wd-%s_rs-1' % (
        model_type, dataset_name, use_slab, step_size, temp, epsilon, weight_decay)
    if percentile != 90:
        name += '_percentile-%s' % percentile
    name += '_em-%s' % max_em_iter
    if baseline:
        name += '_baseline'
    if baseline_smooth:
        name += '_baseline-smooth'
    if no_LP:
        name += '_no-LP'
    if timed:
        name += '_timed'
    name += '_use_train' if use_train else '_use_test'
    return name


def _build_model(model_name):
    """Construct the model selected by ``--model_type`` over the clean ``data_sets``."""
    if args.model_type == 'svm':
        return SmoothHinge(
            input_dim=input_dim,
            temp=temp,
            weight_decay=weight_decay,
            use_bias=True,
            num_classes=num_classes,
            batch_size=batch_size,
            data_sets=data_sets,
            initial_learning_rate=initial_learning_rate,
            decay_epochs=None,
            mini_batch=False,
            train_dir=train_output_root,
            log_dir='log',
            model_name=model_name)
    if args.model_type == 'lr':
        return BinaryLogisticRegressionWithLBFGS(
            input_dim=input_dim,
            temp=temp,
            weight_decay=weight_decay,
            use_bias=True,
            model_rand_seed=24,
            max_lbfgs_iter=10000,
            num_classes=num_classes,
            batch_size=batch_size,
            data_sets=data_sets,
            initial_learning_rate=initial_learning_rate,
            decay_epochs=None,
            mini_batch=False,
            train_dir=train_output_root,
            log_dir='log',
            model_name=model_name,
            rand_seed=args.rand_seed)
    raise ValueError("unsupported --model_type %r (expected 'lr' or 'svm')" % (args.model_type,))


for epsilon in epsilons:
    # Only attack the selected epsilons; skipped ones produce no output.
    if epsilon not in target_epsilons:
        continue
    np.random.seed(args.rand_seed)
    print('========== Epsilon %s ==========' % epsilon)

    model_name = _build_model_name(epsilon)

    # Which training points may be used as starting points for poisons.
    if no_defense:
        feasible_flipped_mask = np.ones(X_train.shape[0], dtype=bool)
    else:
        class_map, centroids, centroid_vec, sphere_radii, slab_radii = data.get_data_params(
            X_train, Y_train, percentile=percentile)

        feasible_flipped_mask = iterative_attack.get_feasible_flipped_mask(
            X_train, Y_train,
            centroids,
            centroid_vec,
            sphere_radii,
            slab_radii,
            class_map,
            use_slab=use_slab)

    X_modified, Y_modified, indices_to_poison, copy_array = iterative_attack.init_gradient_attack_from_mask(
        X_train, Y_train,
        epsilon,
        feasible_flipped_mask,
        use_copy=use_copy,
        pos_frac=pos_frac,
        extreme_start=extreme_start,
        output_root=output_root)

    # Same label mapping as used for data_sets: lr labels go through (y + 1) / 2.
    if args.model_type == 'lr':
        Y_modified_use = np.array((Y_modified + 1) / 2, dtype=int)
    else:
        Y_modified_use = Y_modified

    tf.reset_default_graph()
    model = _build_model(model_name)
    model.update_train_x_y(X_modified, Y_modified_use)
    model.train()

    start_time = time.time() if timed else None

    num_em_iters = max(max_em_iter, 1)

    for em_iter in range(num_em_iters):

        print('\n\n##### EM iter %s #####' % em_iter)
        X_modified = model.data_sets.train.x
        Y_modified = model.data_sets.train.labels

        if no_defense:
            X_lims = get_feature_bounds(X_train, X_test)
            projection_fn = _clip_projection_fn
        else:
            # X_lims is only meaningful for the no-defense clipping projection,
            # but iterative_attack is always called with it. Previously it was
            # unbound here and raised a NameError.
            X_lims = None
            # With em_iter == 0 the projection is fitted on the clean data;
            # otherwise on the current (poisoned) training set.
            if max_em_iter == 0:
                proj_X, proj_Y = X_train, Y_train
            else:
                proj_X, proj_Y = X_modified, Y_modified
            projection_fn = get_projection_fn_for_dataset(
                dataset_name,
                proj_X,
                proj_Y,
                use_slab,
                use_LP,
                percentile)

        iterative_attack.iterative_attack(
            model,
            indices_to_poison=indices_to_poison,
            test_idx=None,
            test_description=None,
            step_size=step_size,
            num_iter=num_grad_iter_per_em,
            loss_type=loss_type,
            projection_fn=projection_fn,
            output_root=output_root,
            num_copies=copy_array,
            stop_after=2,
            start_time=start_time,
            no_defense=no_defense,
            X_lims=X_lims,
            force_iter=force_iter,
            force_refresh=force_refresh,
            epsilon=epsilon,
            discrete=discrete,
            partial_int = partial_int)